import torch as t
import tqdm
import skimage.io as io
import matplotlib.pyplot as plt
import util


def mse(x,y):
  """Return the mean squared error between tensors `x` and `y`.

  Thin wrapper around torch's functional ``mse_loss`` with mean reduction.
  """
  mean_sq_err = t.nn.functional.mse_loss(x, y, reduction="mean")
  return mean_sq_err

def mssimfast_err(x,y,kern,alpha=1,beta=1,gamma=1):
  """Return minus the mean SSIM between two 2-D images (vectorised variant).

  Local statistics are taken over the sliding window `kern` with
  ``conv2d`` and no padding. The SSIM map therefore covers only the valid
  region (H - kh + 1, W - kw + 1). Variances and the covariance use the
  moment form Var = E[x^2] - E[x]^2. That is a true weighted variance only
  when `kern` sums to 1. Presumably `kern` is a normalised window such as
  a Gaussian; confirm at the call site.

  Args:
    x, y: 2-D tensors of the same shape (H, W).
    kern: 2-D window tensor (kh, kw), on the same device/dtype as x and y.
    alpha, beta, gamma: exponents applied to the luminance, contrast and
      structure terms respectively.

  Returns:
    0-d tensor ``-mean(ssim_map)``. It is negated so it can be minimised
    as a loss.
  """
  weight = kern.unsqueeze(0).unsqueeze(0)  # (out_ch=1, in_ch=1, kh, kw)

  def window_mean(img):
    # Unbatched (C=1, H, W) input to conv2d; squeeze the channel back off.
    return t.nn.functional.conv2d(img.unsqueeze(0), weight).squeeze(0)

  mu1 = window_mean(x)
  mu2 = window_mean(y)
  # E[x^2] - E[x]^2 cancels catastrophically in floating point. It can come
  # out slightly negative, e.g. in flat regions or with large intensity
  # offsets. Clamp at zero so the sqrt below can never produce NaN.
  var1 = t.clamp(window_mean(x**2) - mu1**2, min=0)
  var2 = t.clamp(window_mean(y**2) - mu2**2, min=0)
  # The covariance may legitimately be negative (anti-correlated windows).
  sig12 = window_mean(x*y) - mu1*mu2

  # +1e-6 keeps sqrt (and its gradient) finite at zero variance.
  sig1 = t.sqrt(var1+1e-6)
  sig2 = t.sqrt(var2+1e-6)
  # SSIM stabilising constants K1=0.01, K2=0.03, assuming a dynamic range
  # of 1 (inputs in [0, 1]); c3 = c2/2 is the usual simplification.
  c1 = 0.01**2
  c2 = 0.03**2
  c3 = c2/2
  lumQuality = (2*mu1*mu2+c1)/(mu1**2+mu2**2+c1)
  conQuality = (2*sig1*sig2+c2)/(sig1**2+sig2**2+c2)
  strQuality = (sig12+c3)/(sig1*sig2+c3)
  ssim = (lumQuality**alpha)*(conQuality**beta)*(strQuality**gamma)
  return -t.mean(ssim)

def mssim_err(x,y,kern,alpha=1,beta=1,gamma=1):
  """Return minus the mean SSIM between two 2-D images (reference variant).

  Local means come from ``conv2d`` over the window `kern` (no padding, so
  the map is (H - kh + 1, W - kw + 1)). The local variances and the
  covariance are accumulated explicitly in a Python loop over every kernel
  tap. Each tap contributes a weighted product of deviations from the
  local mean.

  Args:
    x, y: 2-D tensors of the same shape (H, W).
    kern: 2-D window tensor (kh, kw).
    alpha, beta, gamma: exponents on the luminance, contrast and structure
      terms.

  Returns:
    0-d tensor ``-mean(ssim_map)``, suitable for minimisation.
  """
  window = kern.unsqueeze(0).unsqueeze(0)
  mu1 = t.nn.functional.conv2d(x.unsqueeze(0), window).squeeze(0)
  mu2 = t.nn.functional.conv2d(y.unsqueeze(0), window).squeeze(0)

  out_h, out_w = mu1.shape[0], mu1.shape[1]
  var1 = t.zeros_like(mu1)
  var2 = t.zeros_like(mu1)
  sig12 = t.zeros_like(mu1)
  # Output pixel (p, q) sees input pixel (p + row, q + col) at tap (row, col),
  # so shifting the whole image by the tap offset lines it up with mu1/mu2.
  for row in range(kern.shape[0]):
    for col in range(kern.shape[1]):
      w = kern[row, col]
      dx = x[row:row + out_h, col:col + out_w] - mu1
      dy = y[row:row + out_h, col:col + out_w] - mu2
      var1 += w*dx*dx
      var2 += w*dy*dy
      sig12 += w*dx*dy

  # +1e-6 keeps sqrt finite/differentiable at zero variance.
  sd1 = t.sqrt(var1+1e-6)
  sd2 = t.sqrt(var2+1e-6)
  c1, c2 = 0.01**2, 0.03**2
  c3 = c2/2
  lum = (2*mu1*mu2+c1)/(mu1**2+mu2**2+c1)
  con = (2*sd1*sd2+c2)/(sd1**2+sd2**2+c2)
  struct = (sig12+c3)/(sd1*sd2+c3)
  return -t.mean((lum**alpha)*(con**beta)*(struct**gamma))


# NOTE: everything below runs at import time. It reads image files from the
# working directory and moves tensors to a CUDA device.

# Distortion kernel plus kh/kw, which distortwiththing uses as conv padding.
# The meaning of the 0.6 argument is not visible here (see util.loadKernSrgbGS).
dis, kh, kw = util.loadKernSrgbGS("badkern.png",0.6)
dis=dis.detach().to("cuda")

# Candidate images, divided by 255 (assumes 8-bit input, giving [0, 1] — TODO confirm).
imgtxt = t.from_numpy(util.tolum(io.imread("text2.png")))/255
# Add tiny Gaussian noise (std 1/1000), then clamp back into [0, 1].
imgtxt = t.clamp(imgtxt+t.randn_like(imgtxt)/1000,0,1)
# Keep every 4th row and column (4x downsampling by striding).
imgnight = t.from_numpy(util.tolum(io.imread("grayscalenight.jpeg"))[::4,::4])/255
imgdir = t.from_numpy(util.tolum(io.imread("ntcl.jpg")))/255

# Working image: the night scene passed through util.inverseSrgb_
# (presumably sRGB -> linear; verify), as float32 on the GPU.
img = util.inverseSrgb_(imgnight).detach().to(t.float32).to("cuda")
print(img.shape)


def distortwiththing(x):
  """Convolve a single-channel 2-D image with the module-level kernel `dis`.

  The input is padded by (kh, kw) on each side. Presumably kh/kw are the
  kernel half-extents, so the output keeps the input's (H, W); TODO confirm
  against util.loadKernSrgbGS.
  """
  kernel4d = dis.unsqueeze(0).unsqueeze(0)  # (out_ch=1, in_ch=1, KH, KW)
  chw = x.unsqueeze(0)                      # unbatched (C=1, H, W) input
  convolved = t.nn.functional.conv2d(chw, weight=kernel4d, padding=(kh, kw))
  return convolved.squeeze(0)

# Use matplotlib's dark theme for every figure this script draws.
plt.style.use('dark_background')


def minDistort(target, distortionfn, lossfn, iters=100, lr=0.01, log_every=10):
  """Find an image x in [0, 1] whose distorted version matches `target`.

  Each iteration takes one AdamW step on ``lossfn(distortionfn(x), target)``.
  It then projects x back onto [0, 1] by clamping. The optimisation starts
  from `target` itself (clamped to [0, 1]).

  Args:
    target: tensor the distorted image should match; also the start point.
    distortionfn: differentiable map applied to x before the loss.
    lossfn: callable ``(distorted, target) -> scalar tensor`` to minimise.
    iters: number of optimisation steps.
    lr: AdamW learning rate.
    log_every: positive int. Losses are recorded on every `log_every`-th
      step and on the final step.

  Returns:
    ``(x, losses)``. `x` is the optimised leaf tensor (requires_grad=True),
    clamped to [0, 1]. `losses` is a list of
    ``[lossfn value before the step, MSE of distortionfn(x) vs target after
    the step]`` pairs.

  Raises:
    ValueError: if `log_every` < 1.
  """
  if log_every < 1:
    raise ValueError("log_every must be >= 1")
  # Detach so x is always a fresh leaf, even if `target` is part of a graph.
  x = target.detach().clone().clamp_(0, 1).requires_grad_(True)
  optimizer = t.optim.AdamW([x], lr=lr)
  b = target.clone().detach()
  losses = []
  for i in tqdm.trange(iters):
    optimizer.zero_grad()
    loss = lossfn(distortionfn(x), b)
    loss.backward()
    optimizer.step()
    with t.no_grad():
      # Project after the step so the returned iterate is also in [0, 1].
      x.clamp_(0, 1)
      if i % log_every == 0 or i == iters-1:
        loss2 = mse(distortionfn(x), b)
        losses.append([loss.item(), loss2.item()])
  return x, losses

# SSIM window from util.fspecialgaussian(5, 2.5) (presumably 5x5, sigma 2.5), on the GPU.
gaus = util.fspecialgaussian(5,2.5).to("cuda")


def bloss(x, y):
  """SSIM-based loss used by the solvers below.

  Both images are mapped through util.srgb_ (presumably linear -> sRGB;
  verify). The result is minus the mean SSIM over the `gaus` window, with
  a weak luminance term (alpha=0.03) and a squared structure term (gamma=2).
  """
  return mssim_err(util.srgb_(x), util.srgb_(y), gaus, alpha=0.03, gamma=2)

# nim, loss = minDistort(
#   img, distortionfn=distortwiththing, lossfn=bloss,iters=100
#   #img, distortionfn=distortwiththing, lossfn=lambda x,y:mse(x,y),iters=100
# )
# print(loss)
def gssolve(y,iters=100):
  """Optimise a single-channel image so that `distortwiththing` of it matches `y`.

  `y` is first perturbed with small Gaussian noise (std 1/1000) and clamped
  to [0, 1]. It is then optimised with `bloss` via `minDistort`. The loss
  history is printed, and the optimised image is returned.
  """
  noisy = t.clamp(y + t.randn_like(y)/1000, 0, 1)
  result, history = minDistort(noisy, distortwiththing, bloss, iters)
  print(history)
  return result
def rgbsolve(y, iters=100):
  """Channel-wise version of `gssolve` for a channel-last image.

  Noise (std 1/1000) is added to the whole image and the result clamped to
  [0, 1]. Channels 0, 1 and 2 are then optimised independently with
  `minDistort`, and the solutions are restacked along the last axis. Loss
  histories are discarded.
  """
  y = t.clamp(y + t.randn_like(y)/1000, 0, 1)
  solved = [minDistort(y[:, :, ch], distortwiththing, bloss, iters)[0] for ch in range(3)]
  return t.stack(solved, dim=-1)

def rgbDistort(y):
  """Apply `distortwiththing` to channels 0-2 of a channel-last image and restack them."""
  channels = []
  for ch in range(3):
    channels.append(distortwiththing(y[:, :, ch]))
  return t.stack(channels, dim=-1)

#fpath = ["ntcl.jpg","summer-landscape-with-river.jpg","pathnight.jpg","carnival.jpg"][3]
#imgnctl = util.inverseSrgb_(t.from_numpy(io.imread(fpath)).to(t.float32)/255).detach().to("cuda")



# plt.imshow(util.srgb_(distortwiththing(nim)).to("cpu").detach(),cmap="gray")
# plt.show()
from datetime import datetime

# plt.imshow(util.srgb_(nim   ).to("cpu").detach(),cmap="gray")
if __name__ == "__main__":
  # Disabled experiment: solve the grayscale `img`, build a comparison via
  # util.rgbShow, display it and save it under results/ with a timestamp.
  # result = util.rgbShow(util.torgb(img),util.torgb(gssolve(img, 100)),distortwiththing)
  # plt.imshow(result,cmap="gray")
  # print(result, result.shape)
  # plt.show()
  # filename = "results/" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".png"
  # plt.imsave(filename,result.numpy())

  # Current demo (rebinds the module-level `img`). Load ntcl.jpg on the GPU
  # and divide by 255 (assumes 8-bit input). Remap to [0.1, 1.0] via
  # x*0.9 + 0.1, then pass through util.inverseSrgb_. Distort each of the
  # first three channels with rgbDistort (assumes the file has >= 3
  # channels — TODO confirm) and display the result after util.srgb_.
  img = util.inverseSrgb_((t.from_numpy(io.imread("ntcl.jpg")).to("cuda")/255)*0.9+0.1)
  print(img.shape)
  plt.imshow(util.srgb_(rgbDistort(img)).to("cpu"))
  plt.show()

  # Disabled sweep: for five luminance exponents (alpha), optimise `img`
  # with mssim_err and tile the distorted results side by side in one figure.
  # h = img.shape[0]
  # w=img.shape[1]
  # imgs = t.empty((h,w*5)).to("cuda")
  # alphas = [0.01,0.03,0.1,0.3,1]
  # for i in range(5):
  #   im, l = minDistort(
  #     img, distortwiththing,
  #     lambda x,y:mssim_err(util.srgb_(x),util.srgb_(y),gaus,alpha=alphas[i],gamma=2), iters=80
  #   )
  #   imgs[:,img.shape[1]*i:img.shape[1]*(i+1)] = util.srgb_(distortwiththing(im))
  # plt.imshow(imgs.to("cpu").detach(), cmap="gray")
  # plt.show()



